{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "gpuClass": "standard"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# KIT-Loe-GE Cell Segmentation and Tracking\n",
        "\n",
        "\n",
        "Simultaneous cell segmentation and tracking method used for our submission as team KIT-Loe-GE to the [Cell Tracking Challenge](http://celltrackingchallenge.net/) in 2022.\n",
        "\n",
        "The code is publicly available at https://github.com/kaloeffler/EmbedTrack.\n",
        "\n",
        "----\n",
        "\n",
        "Publication:\n",
        "K. Löffler and M. Mikut (2022). EmbedTrack -- Simultaneous Cell Segmentation and Tracking Through Learning Offsets and Clustering Bandwidths. arXiv preprint. DOI: [10.48550/arXiv.2204.10713](https://doi.org/10.48550/arXiv.2204.10713)\n",
        "\n",
        "----\n"
      ],
      "metadata": {
        "id": "is-uQzzdqaYg"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 1) Setting up the environment"
      ],
      "metadata": {
        "id": "maazx1Kl8hHM"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Creating the environment, cloning the code and adding some utilities for downloading the CTC data. Everything (data, code, trained models) will be stored in your personal google drive folder ('/content/drive/MyDrive') in a folder named \"EmbedTrack\", so you can access it later."
      ],
      "metadata": {
        "id": "LrJriGTywQv9"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OdXotXmQvvoO"
      },
      "outputs": [],
      "source": [
        "from google.colab import drive\n",
        "drive.mount('/content/drive')"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "%cd drive/MyDrive/\n",
        "!git clone https://github.com/kaloeffler/EmbedTrack.git\n"
      ],
      "metadata": {
        "id": "hE3FvotewLju"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "!conda --version"
      ],
      "metadata": {
        "id": "iqvp9UxvwW30"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install -q condacolab\n",
        "import condacolab\n",
        "condacolab.install()"
      ],
      "metadata": {
        "id": "fgoMCk_Bwy0m"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "!conda env create -f /content/drive/MyDrive/EmbedTrack/environment.yml"
      ],
      "metadata": {
        "id": "KKJ7WJUpw3O9"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "%%shell\n",
        "eval \"$(conda shell.bash hook)\" # copy conda command to shell\n",
        "conda activate venv_embedtrack\n",
        "which python\n",
        "python --version\n"
      ],
      "metadata": {
        "id": "7Uy2FI-rxxBY"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install imagecodecs --no-dependencies\n",
        "!pip install cffi==\"1.15.0\""
      ],
      "metadata": {
        "id": "v2FtmKdJgN-o"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Restarting the runtime\n",
        "get_ipython().kernel.do_shutdown(True)"
      ],
      "metadata": {
        "id": "AVs2_il5du8z"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "**Check cuda is available - otherwise set in colab under \"runtime\" -> \"change runtime type\" the runtime from \"None\" to \"GPU\"**\n",
        "\n"
      ],
      "metadata": {
        "id": "kHEdAnuIPyq9"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "print(torch.cuda.is_available())"
      ],
      "metadata": {
        "id": "IUjDlAlljZGp"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "**Utilities to facilitate downloading data from the Cell Tracking Challenge**\n",
        "\n",
        "Please note: you need to run this cell before jumping to the training / inference sections!"
      ],
      "metadata": {
        "id": "hf9tQnviwSn5"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import requests\n",
        "import zipfile\n",
        "import os\n",
        "\n",
        "def retrieve_ctc_data(url, save_dir):\n",
        "  zip_file = os.path.join(save_dir, url.split(\"/\")[-1])\n",
        "  with requests.get(url, stream=True) as req:\n",
        "    req.raise_for_status()\n",
        "    with open(zip_file, \"wb\") as file: \n",
        "      for chunk in req.iter_content(chunk_size=8192):\n",
        "        file.write(chunk)\n",
        "  print(f\"Unzip data set {os.path.basename(zip_file)}\")\n",
        "  with zipfile.ZipFile(zip_file) as z:\n",
        "    z.extractall(save_dir)\n",
        "  \n",
        "  os.remove(zip_file)\n",
        "      "
      ],
      "metadata": {
        "id": "DK-quV8_F9qM"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 2.) Training and Inference"
      ],
      "metadata": {
        "id": "RjdyDRbU8xtm"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### 2.1.) Select a data set to do training / inference on\n",
        "\n",
        "EmbedTrack was tested and trained on the following 2D datasets, as they all provide an additional Silver Truth (ST) which will be processed together with the Gold Truth annotations (GT) to get fully labelled cell segmentation masks with resonable annotation quality: \n",
        "\"Fluo-N2DH-SIM+\",\n",
        "  \"Fluo-C2DL-MSC\",\n",
        "    \"Fluo-N2DH-GOWT1\",\n",
        "    \"PhC-C2DL-PSC\",\n",
        "    \"BF-C2DL-HSC\",\n",
        "    \"Fluo-N2DL-HeLa\",\n",
        "    \"BF-C2DL-MuSC\",\n",
        "    \"DIC-C2DH-HeLa\", and\n",
        "    \"PhC-C2DH-U373\"."
      ],
      "metadata": {
        "id": "o2bjJJmLTUMp"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# possible data sets:\n",
        "\n",
        "#[    \"Fluo-N2DH-SIM+\",\n",
        "#    \"Fluo-C2DL-MSC\",\n",
        "#    \"Fluo-N2DH-GOWT1\",\n",
        "#    \"PhC-C2DL-PSC\",\n",
        "#    \"BF-C2DL-HSC\",\n",
        "#    \"Fluo-N2DL-HeLa\",\n",
        "#    \"BF-C2DL-MuSC\",\n",
        "#    \"DIC-C2DH-HeLa\",\n",
        "#    \"PhC-C2DH-U373\",\n",
        "#]\n",
        "\n",
        "data_set = \"Fluo-N2DH-SIM+\"\n"
      ],
      "metadata": {
        "id": "NUCDK7XcyV0C"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### 2.2.) Download the selected data set from the Cell Tracking Challenge"
      ],
      "metadata": {
        "id": "unZRe2zW9NVq"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# change to the embedtrack folder that has been created in your drive\n",
        "%cd /content/drive/MyDrive/EmbedTrack/\n",
        "!ls"
      ],
      "metadata": {
        "id": "QSgILgio0bOt"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "from pathlib import Path\n",
        "\n",
        "ctc_data_url = \"http://data.celltrackingchallenge.net\"\n",
        "ctc_metrics_url = \"http://public.celltrackingchallenge.net/software/EvaluationSoftware.zip\"\n",
        "\n",
        "training_data_url = os.path.join(ctc_data_url, \"training-datasets/\")\n",
        "challenge_data_url = os.path.join(ctc_data_url, \"challenge-datasets/\")\n",
        "\n",
        "current_path = Path.cwd()\n",
        "data_path = current_path / 'ctc_raw_data'\n",
        "ctc_metrics_path = os.path.join(current_path, \"embedtrack\", \"ctc_metrics\", \"CTC_eval\")\n",
        "\n",
        "# Download training data set\n",
        "if not os.path.exists(data_path / \"train\" / data_set):\n",
        "  dp = os.path.join(data_path, \"train\", data_set)\n",
        "  print(f\"Downloading training data set to {dp} ...\")\n",
        "  data_url = training_data_url + data_set + \".zip\"\n",
        "  retrieve_ctc_data(data_url, os.path.join(data_path, \"train\"))\n",
        "\n",
        "# Download challenge data set\n",
        "if not os.path.exists(data_path / \"challenge\" / data_set):\n",
        "  dp = os.path.join(data_path, \"challenge\", data_set)\n",
        "  print(f\"Downloading challenge data set to {dp} ...\")\n",
        "  data_url = challenge_data_url + data_set + \".zip\"\n",
        "  retrieve_ctc_data(data_url, os.path.join(data_path, \"challenge\"))\n",
        "\n",
        "# Download evaluation software\n",
        "if len(os.listdir(ctc_metrics_path)) <= 1:\n",
        "  print(f\"Downloading  ctc metrics to {ctc_metrics_path} ...\")\n",
        "  retrieve_ctc_data(ctc_metrics_url, ctc_metrics_path)\n",
        " \n",
        "# make CTC metrics executable\n",
        "!chmod -R 755 $ctc_metrics_path"
      ],
      "metadata": {
        "id": "7-j-ZXo20Ww-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### 2.3.) Train a model for the selected data set"
      ],
      "metadata": {
        "id": "IM9SX9IY9VsK"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# change to the embedtrack folder that has been created in your drive\n",
        "%cd /content/drive/MyDrive/EmbedTrack/\n",
        "!ls"
      ],
      "metadata": {
        "id": "qHXAwGBpoPst"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import matplotlib\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "matplotlib.use(\"Agg\")\n",
        "from embedtrack.train.run_training_pipeline import (\n",
        "    DataConfig,\n",
        "    ModelConfig,\n",
        "    TrainConfig,\n",
        "    run_pipeline,\n",
        ")\n",
        "import os\n",
        "from pathlib import Path\n",
        "\n",
        "# data configs\n",
        "\n",
        "PROJECT_PATH = Path.cwd()\n",
        "\n",
        "RAW_DATA_PATH = os.path.join(PROJECT_PATH, \"ctc_raw_data/train\")\n",
        "DATA_PATH_DEST = os.path.join(PROJECT_PATH, \"data\")\n",
        "MODEL_PATH = os.path.join(PROJECT_PATH, \"models\")\n",
        "\n",
        "USE_SILVER_TRUTH = True\n",
        "TRAIN_VAL_SEQUNCES = [\"01\", \"02\"]\n",
        "TRAIN_VAL_SPLIT = 0.1\n",
        "\n",
        "N_EPOCHS = 15\n",
        "# Adam optimizer; normalize images; OneCycle LR sheduler; N epochs\n",
        "MODEL_NAME = \"adam_norm_onecycle_\" + str(N_EPOCHS)\n",
        "\n",
        "if data_set == \"Fluo-N2DH-SIM+\":\n",
        "    use_silver_truth = False\n",
        "else:\n",
        "    use_silver_truth = USE_SILVER_TRUTH\n",
        "\n",
        "data_config = DataConfig(\n",
        "    RAW_DATA_PATH,\n",
        "    data_set,\n",
        "    DATA_PATH_DEST,\n",
        "    use_silver_truth=use_silver_truth,\n",
        "    train_val_sequences=TRAIN_VAL_SEQUNCES,\n",
        "    train_val_split=TRAIN_VAL_SPLIT,\n",
        ")\n",
        "\n",
        "# train configs\n",
        "MODEL_SAVE_DIR = os.path.join(\n",
        "    MODEL_PATH,\n",
        "    data_set,\n",
        "    MODEL_NAME,\n",
        ")\n",
        "if data_set != \"Fluo-C2DL-MSC\":\n",
        "    CROP_SIZE = 256\n",
        "    TRAIN_BATCH_SIZE = 16\n",
        "    VAL_BATCH_SIZE = 16\n",
        "    DISPLAY_IT = 1000\n",
        "\n",
        "else:\n",
        "    CROP_SIZE = 512\n",
        "    TRAIN_BATCH_SIZE = 8\n",
        "    VAL_BATCH_SIZE = 8\n",
        "    DISPLAY_IT = 200\n",
        "\n",
        "CENTER = \"medoid\"  \n",
        "RESUME_TRAINING = False\n",
        "TRAIN_SIZE = None  # if None training on full train data set; otherwise still training on full data set but only use a fraction of the data per epoch\n",
        "VAL_SIZE = None  # if None validation on full val data set; otherwise still val on full data set but only use a fraction of the data per epoch\n",
        "VIRTUAL_TRAIN_BATCH_MULTIPLIER = 1\n",
        "VIRTUAL_VAL_BATCH_MULTIPLIER = 1\n",
        "DISPLAY = False\n",
        "\n",
        "train_config = TrainConfig(\n",
        "    MODEL_SAVE_DIR,\n",
        "    crop_size=CROP_SIZE,\n",
        "    center=CENTER,\n",
        "    resume_training=RESUME_TRAINING,\n",
        "    train_size=TRAIN_SIZE,\n",
        "    train_batch_size=TRAIN_BATCH_SIZE,\n",
        "    virtual_train_batch_multiplier=VIRTUAL_TRAIN_BATCH_MULTIPLIER,\n",
        "    val_size=VAL_SIZE,\n",
        "    val_batch_size=VAL_BATCH_SIZE,\n",
        "    virtual_val_batch_multiplier=VIRTUAL_VAL_BATCH_MULTIPLIER,\n",
        "    n_epochs=N_EPOCHS,\n",
        "    display=DISPLAY,\n",
        "    display_it=DISPLAY_IT,\n",
        ")\n",
        "\n",
        "# model config\n",
        "INPUT_CHANNELS = 1\n",
        "N_SEG_CLASSES = [4, 1]\n",
        "N_TRACK_CLASSES = 2\n",
        "\n",
        "model_config = ModelConfig(INPUT_CHANNELS, N_SEG_CLASSES, N_TRACK_CLASSES)\n",
        "\n",
        "run_pipeline(data_config, train_config, model_config)\n",
        "plt.close(\"all\")\n"
      ],
      "metadata": {
        "id": "OPnQmrb3yLQf"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### 2.4.) Inference using the just trained model"
      ],
      "metadata": {
        "id": "c-4C4rUrqck5"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# change to the embedtrack folder that has been created in your drive\n",
        "%cd /content/drive/MyDrive/EmbedTrack/\n",
        "!ls"
      ],
      "metadata": {
        "id": "rhRTT_SZ5nKw"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# make CTC metrics executable\n",
        "current_path = Path.cwd()\n",
        "ctc_metrics_path = os.path.join(current_path, \"embedtrack\", \"ctc_metrics\", \"CTC_eval\")\n",
        "!chmod -R 755 $ctc_metrics_path"
      ],
      "metadata": {
        "id": "T9T4UvbwjL9H"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "from datetime import datetime\n",
        "from pathlib import Path\n",
        "from time import time\n",
        "import shutil\n",
        "from embedtrack.ctc_metrics.eval_ctc import calc_ctc_scores\n",
        "from embedtrack.infer.infer_ctc_data import inference\n",
        "\n",
        "PROJECT_PATH = Path.cwd()\n",
        "\n",
        "RAW_DATA_PATHS = [os.path.join(PROJECT_PATH, \"ctc_raw_data/challenge\"),\n",
        "                  os.path.join(PROJECT_PATH, \"ctc_raw_data/train\")]\n",
        "MODEL_PATH = os.path.join(PROJECT_PATH, \"models\")\n",
        "RES_PATH = os.path.join(PROJECT_PATH, \"results\")\n",
        "\n",
        "# Adam optimizer; normalize images; OneCycle LR sheduler; N epochs\n",
        "MODEL_NAME = \"adam_norm_onecycle_15\"\n",
        "BATCH_SIZE = 32\n",
        "\n",
        "for raw_data_path in RAW_DATA_PATHS:\n",
        "      for data_id in [\"01\", \"02\"]:\n",
        "          img_path = os.path.join(raw_data_path, data_set, data_id)\n",
        "\n",
        "          model_dir = os.path.join(MODEL_PATH, data_set, MODEL_NAME)\n",
        "          if not os.path.exists(model_dir):\n",
        "              print(f\"no trained model for data set {data_set}\")\n",
        "              continue\n",
        "\n",
        "          # time stamps\n",
        "          timestamps_trained_models = [\n",
        "              datetime.strptime(time_stamp, \"%Y-%m-%d---%H-%M-%S\")\n",
        "              for time_stamp in os.listdir(model_dir)\n",
        "          ]\n",
        "          timestamps_trained_models.sort()\n",
        "          last_model = timestamps_trained_models[-1].strftime(\"%Y-%m-%d---%H-%M-%S\")\n",
        "          model_path = os.path.join(model_dir, last_model, \"best_iou_model.pth\")\n",
        "          config_file = os.path.join(model_dir, last_model, \"config.json\")\n",
        "          t_start = time()\n",
        "          inference(img_path, model_path, config_file, batch_size=BATCH_SIZE)\n",
        "          t_end = time()\n",
        "\n",
        "          run_time = t_end - t_start\n",
        "          print(f\"Image sequence: {img_path}\")\n",
        "          print(f\"Inference Time {img_path}: {run_time}s\")\n",
        "\n",
        "          res_path = os.path.join(RES_PATH, data_set, MODEL_NAME, last_model, os.path.basename(raw_data_path), data_id+\"_RES\")\n",
        "          if not os.path.exists(os.path.dirname(res_path)):\n",
        "            os.makedirs(os.path.dirname(res_path))\n",
        "          shutil.move(img_path+\"_RES\", res_path)\n",
        "          if os.path.basename(raw_data_path) == \"train\":\n",
        "            metrics = calc_ctc_scores(Path(res_path), Path(img_path+\"_GT\"))\n",
        "            print(metrics)\n",
        "\n"
      ],
      "metadata": {
        "id": "v62G_sR45aC1"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 3.) Inference using the models submitted to the CTC\n",
        "\n",
        "Download the trained models submitted to the CTC and use them for inference."
      ],
      "metadata": {
        "id": "SScwJ45yqU24"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# select data set to do inference on\n",
        "# possible data sets:\n",
        "#[    \"Fluo-N2DH-SIM+\",\n",
        "#    \"Fluo-C2DL-MSC\",\n",
        "#    \"Fluo-N2DH-GOWT1\",\n",
        "#    \"PhC-C2DL-PSC\",\n",
        "#    \"BF-C2DL-HSC\",\n",
        "#    \"Fluo-N2DL-HeLa\",\n",
        "#    \"BF-C2DL-MuSC\",\n",
        "#    \"DIC-C2DH-HeLa\",\n",
        "#    \"PhC-C2DH-U373\",\n",
        "#]\n",
        "\n",
        "data_set = \"Fluo-N2DH-SIM+\""
      ],
      "metadata": {
        "id": "Jho1KTBq2NTy"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# change to the embedtrack folder that has been created in your drive\n",
        "%cd /content/drive/MyDrive/EmbedTrack/\n",
        "!ls"
      ],
      "metadata": {
        "id": "vxzHPXIa4CD_"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "from pathlib import Path\n",
        "executables_url = \"http://public.celltrackingchallenge.net/participants/KIT-Loe-GE.zip\"\n",
        "executables_path = Path.cwd()\n",
        "# Download trained models and executables submitted to the CTC\n",
        "if not os.path.exists(os.path.join(executables_path, \"KIT-Loe-GE\")):\n",
        "  dp = os.path.join(executables_path, \"KIT-Loe-GE\")\n",
        "  print(f\"Downloading trained models and excetuables of KIT-Loe-GE to {dp} ...\")\n",
        "  retrieve_ctc_data(executables_url, executables_path)"
      ],
      "metadata": {
        "id": "OBv4GQvOqaFU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "from pathlib import Path\n",
        "\n",
        "ctc_data_url = \"http://data.celltrackingchallenge.net\"\n",
        "ctc_metrics_url = \"http://public.celltrackingchallenge.net/software/EvaluationSoftware.zip\"\n",
        "\n",
        "training_data_url = os.path.join(ctc_data_url, \"training-datasets/\")\n",
        "challenge_data_url = os.path.join(ctc_data_url, \"challenge-datasets/\")\n",
        "\n",
        "current_path = Path.cwd()\n",
        "data_path = current_path / 'ctc_raw_data'\n",
        "ctc_metrics_path = os.path.join(current_path, \"embedtrack\", \"ctc_metrics\", \"CTC_eval\")\n",
        "\n",
        "# Download training data set\n",
        "if not os.path.exists(data_path / \"train\" / data_set):\n",
        "  dp = os.path.join(data_path, \"train\", data_set)\n",
        "  print(f\"Downloading training data set to {dp} ...\")\n",
        "  data_url = training_data_url + data_set + \".zip\"\n",
        "  retrieve_ctc_data(data_url, os.path.join(data_path, \"train\"))\n",
        "\n",
        "# Download challenge data set\n",
        "if not os.path.exists(data_path / \"challenge\" / data_set):\n",
        "  dp = os.path.join(data_path, \"challenge\", data_set)\n",
        "  print(f\"Downloading challenge data set to {dp} ...\")\n",
        "  data_url = challenge_data_url + data_set + \".zip\"\n",
        "  retrieve_ctc_data(data_url, os.path.join(data_path, \"challenge\"))\n",
        "\n",
        "# Download evaluation software\n",
        "if len(os.listdir(ctc_metrics_path)) <= 1:\n",
        "  print(f\"Downloading  ctc metrics to {ctc_metrics_path} ...\")\n",
        "  retrieve_ctc_data(ctc_metrics_url, ctc_metrics_path)"
      ],
      "metadata": {
        "id": "jc5-YpnGOlAp"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# make CTC metrics executable\n",
        "current_path = Path.cwd()\n",
        "ctc_metrics_path = os.path.join(current_path, \"embedtrack\", \"ctc_metrics\", \"CTC_eval\")\n",
        "!chmod -R 755 $ctc_metrics_path"
      ],
      "metadata": {
        "id": "S1i_FCqxjWa1"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "\n",
        "from pathlib import Path\n",
        "from time import time\n",
        "import shutil\n",
        "from embedtrack.ctc_metrics.eval_ctc import calc_ctc_scores\n",
        "from embedtrack.infer.infer_ctc_data import inference\n",
        "\n",
        "PROJECT_PATH = \"/content/drive/MyDrive/EmbedTrack/\"\n",
        "\n",
        "RAW_DATA_PATHS = [os.path.join(PROJECT_PATH, \"ctc_raw_data/challenge\"),\n",
        "                  os.path.join(PROJECT_PATH, \"ctc_raw_data/train\")]\n",
        "MODEL_PATH = os.path.join(PROJECT_PATH, \"KIT-Loe-GE\", \"models\")\n",
        "RES_PATH = os.path.join(PROJECT_PATH, \"results\")\n",
        "\n",
        "BATCH_SIZE = 32\n",
        "for raw_data_path in RAW_DATA_PATHS:\n",
        "      for data_id in [\"01\", \"02\"]:\n",
        "          img_path = os.path.join(raw_data_path, data_set, data_id)\n",
        "\n",
        "          model_dir = os.path.join(MODEL_PATH, data_set)\n",
        "          if not os.path.exists(model_dir):\n",
        "              print(f\"no trained model for data set {data_set}\")\n",
        "              continue\n",
        "          \n",
        "          model_path = os.path.join(model_dir, \"best_iou_model.pth\")\n",
        "          config_file = os.path.join(model_dir, \"config.json\")\n",
        "          t_start = time()\n",
        "          inference(img_path, model_path, config_file, batch_size=BATCH_SIZE)\n",
        "          t_end = time()\n",
        "\n",
        "          run_time = t_end - t_start\n",
        "          print(f\"Image sequence: {img_path}\")\n",
        "          print(f\"Inference Time {img_path}: {run_time}s\")\n",
        "\n",
        "          res_path = os.path.join(RES_PATH, data_set, \"KIT-Loe-GE\", os.path.basename(raw_data_path), data_id+\"_RES\")\n",
        "          if not os.path.exists(os.path.dirname(res_path)):\n",
        "            os.makedirs(os.path.dirname(res_path))\n",
        "          shutil.move(img_path+\"_RES\", res_path)\n",
        "          if os.path.basename(raw_data_path) == \"train\":\n",
        "            metrics = calc_ctc_scores(Path(res_path), Path(img_path+\"_GT\"))\n",
        "            print(metrics)\n",
        "\n"
      ],
      "metadata": {
        "id": "PC-SJObj1Csp"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}